{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import os.path\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "import input_data_vw as input_data\n",
    "import VGG\n",
    "import tools\n",
    "\n",
    "# 预训练好的npy模型\n",
    "pre_trained_weights = '../vgg16_pretrain/vgg16.npy'\n",
    "# 训练数据集合验证数据集的路径\n",
    "data_dir = \"/home/gps/HDD/dataset_awr/RTM_v4_LargeScale/trains/\"\n",
    "val_dir = \"/home/gps/HDD/dataset_awr/RTM_v4_LargeScale/test/\"\n",
    "# 训练日志路径\n",
    "train_log_dir = './logs/train//'\n",
    "val_log_dir = './logs//val//'\n",
    "\n",
    "# Training\n",
    "def train():\n",
    "    \"\"\"Fine-tune VGG16 on the training set, with periodic logging and checkpoints.\n",
    "\n",
    "    Builds the graph on placeholders, initialises all variables, loads the\n",
    "    pre-trained weights from ``pre_trained_weights`` (skipping fc6/fc7/fc8) and\n",
    "    runs up to ``MAX_STEP`` optimisation steps. Every 20 steps the training\n",
    "    summaries are written, every 100 steps one validation batch is scored, and\n",
    "    every 500 steps a checkpoint is saved under ``train_log_dir`` and passed to\n",
    "    ``evaluate``. The final step triggers all three.\n",
    "    \"\"\"\n",
    "    IS_PRETRAIN = True\n",
    "    # Hyper-parameters\n",
    "    N_CLASSES = 5\n",
    "    IMG_W = 208  # images are resized; large images make training very slow\n",
    "    IMG_H = 208\n",
    "    BATCH_SIZE = 16\n",
    "    CAPACITY = 2000\n",
    "    learning_rate = 0.003\n",
    "    MAX_STEP = 15000  # total number of training iterations\n",
    "\n",
    "    # Queue-based input pipelines for the training and validation sets\n",
    "    with tf.name_scope('input'):\n",
    "        tra_image_batch, tra_label_batch = input_data.get_batch(\n",
    "            data_dir, IMG_W, IMG_H, BATCH_SIZE, CAPACITY, N_CLASSES)\n",
    "        val_image_batch, val_label_batch = input_data.get_batch(\n",
    "            val_dir, IMG_W, IMG_H, BATCH_SIZE, CAPACITY, N_CLASSES)\n",
    "\n",
    "    # Placeholders for one batch of images and labels. The network, loss and\n",
    "    # accuracy are built on these (not directly on the training queue) so that\n",
    "    # feed_dict really selects which batch is scored; if the graph consumed the\n",
    "    # queue tensors, the feeds would be ignored and the validation metrics\n",
    "    # would be computed on training data.\n",
    "    # The dtypes follow the input pipeline so that VGG/tools receive the same\n",
    "    # tensor types as the batches produced by input_data.get_batch.\n",
    "    data_ph = tf.placeholder(tra_image_batch.dtype, shape=[BATCH_SIZE, IMG_W, IMG_H, 3])\n",
    "    label_ph = tf.placeholder(tra_label_batch.dtype, shape=[BATCH_SIZE, N_CLASSES])\n",
    "    # Network\n",
    "    logits = VGG.VGG16N(data_ph, N_CLASSES, IS_PRETRAIN)\n",
    "    # Loss op\n",
    "    loss = tools.loss(logits, label_ph)\n",
    "    # Accuracy op\n",
    "    accuracy = tools.accuracy(logits, label_ph)\n",
    "    # Global step counter\n",
    "    my_global_step = tf.Variable(0, name='global_step', trainable=False)\n",
    "    # Training op\n",
    "    train_op = tools.optimize(loss, learning_rate, my_global_step)\n",
    "    # Checkpoint saver\n",
    "    saver = tf.train.Saver(tf.global_variables())\n",
    "    # Merged summaries\n",
    "    summary_op = tf.summary.merge_all()\n",
    "    # Variable initialiser\n",
    "    init = tf.global_variables_initializer()\n",
    "\n",
    "    sess = tf.Session()\n",
    "    sess.run(init)\n",
    "\n",
    "    # Load the pre-trained VGG parameters, skipping the listed layers\n",
    "    tools.load_with_skip(pre_trained_weights, sess, ['fc6','fc7','fc8'])\n",
    "\n",
    "    coord = tf.train.Coordinator()\n",
    "    threads = tf.train.start_queue_runners(sess=sess, coord=coord)\n",
    "    tra_summary_writer = tf.summary.FileWriter(train_log_dir, sess.graph)\n",
    "    val_summary_writer = tf.summary.FileWriter(val_log_dir, sess.graph)\n",
    "\n",
    "    try:\n",
    "        for step in range(MAX_STEP):\n",
    "            if coord.should_stop():\n",
    "                break\n",
    "            is_last_step = (step + 1) == MAX_STEP\n",
    "            # The batch tensors must be evaluated to numpy arrays before they\n",
    "            # can be fed to the placeholders the graph is built on.\n",
    "            tra_images, tra_labels = sess.run([tra_image_batch, tra_label_batch])\n",
    "            train_feed = {data_ph: tra_images, label_ph: tra_labels}\n",
    "            _, tra_loss, tra_acc = sess.run([train_op, loss, accuracy], feed_dict=train_feed)\n",
    "            # Every 20 steps: report and log training loss/accuracy\n",
    "            if step % 20 == 0 or is_last_step:\n",
    "                print ('Step: %d, loss: %.4f, accuracy: %.4f%%' % (step, tra_loss, tra_acc))\n",
    "                # Summaries depend on the placeholders, so feed the same batch\n",
    "                summary_str = sess.run(summary_op, feed_dict=train_feed)\n",
    "                tra_summary_writer.add_summary(summary_str, step)\n",
    "            # Every 100 steps: score one batch of validation data\n",
    "            if step % 100 == 0 or is_last_step:\n",
    "                val_images, val_labels = sess.run([val_image_batch, val_label_batch])\n",
    "                val_feed = {data_ph: val_images, label_ph: val_labels}\n",
    "                val_loss, val_acc = sess.run([loss, accuracy], feed_dict=val_feed)\n",
    "                print('**  Step %d, val loss = %.2f, val accuracy = %.2f%%  **' %(step, val_loss, val_acc))\n",
    "                summary_str = sess.run(summary_op, feed_dict=val_feed)\n",
    "                val_summary_writer.add_summary(summary_str, step)\n",
    "            # Every 500 steps: save a checkpoint and evaluate it\n",
    "            if step % 500 == 0 or is_last_step:\n",
    "                checkpoint_path = os.path.join(train_log_dir, 'model')\n",
    "                saver.save(sess, checkpoint_path, global_step=step)\n",
    "                evaluate(step)\n",
    "    except tf.errors.OutOfRangeError:\n",
    "        print('Done training -- epoch limit reached')\n",
    "    finally:\n",
    "        # Release everything even when an unexpected exception propagates\n",
    "        coord.request_stop()\n",
    "        try:\n",
    "            coord.join(threads)\n",
    "        finally:\n",
    "            tra_summary_writer.close()\n",
    "            val_summary_writer.close()\n",
    "            sess.close()\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "    \n",
    "#%%   Test the accuracy on test dataset. got about 85.69% accuracy.\n",
    "import math\n",
    "def evaluate(train_step, img_w=208, img_h=208, batch_size=16, capacity=2000, n_classes=5):\n",
    "    \"\"\"Restore the checkpoint saved at ``train_step`` and report validation accuracy.\n",
    "\n",
    "    A fresh graph is built for the evaluation. The input settings are keyword\n",
    "    arguments rather than globals because the training hyper-parameters are\n",
    "    locals of ``train`` and are not visible here; the defaults mirror the\n",
    "    values hard-coded in ``train`` -- keep the two in sync.\n",
    "\n",
    "    Args:\n",
    "        train_step: step suffix of the checkpoint to restore\n",
    "            (``<train_log_dir>/model-<train_step>``).\n",
    "        img_w, img_h: image width and height passed to ``input_data.get_batch``.\n",
    "        batch_size: number of samples per evaluation batch.\n",
    "        capacity: capacity argument passed to ``input_data.get_batch``.\n",
    "        n_classes: number of classes passed to the input pipeline and network.\n",
    "    \"\"\"\n",
    "    IS_PRETRAIN = False\n",
    "    with tf.Graph().as_default():\n",
    "\n",
    "        test_dir = val_dir\n",
    "        # NOTE(review): counts the entries of val_dir and treats each one as a\n",
    "        # sample -- confirm this matches how input_data.get_batch lists files.\n",
    "        n_test = len(os.listdir(test_dir))\n",
    "        images, labels = input_data.get_batch(test_dir, img_w, img_h, batch_size, capacity, n_classes)\n",
    "\n",
    "        logits = VGG.VGG16N(images, n_classes, IS_PRETRAIN)\n",
    "        correct = tools.num_correct_prediction(logits, labels)\n",
    "        saver = tf.train.Saver(tf.global_variables())\n",
    "\n",
    "        with tf.Session() as sess:\n",
    "\n",
    "            print('Reading checkpoints...')\n",
    "            # Same prefix train() hands to saver.save, plus the step suffix\n",
    "            model = os.path.join(train_log_dir, 'model-%d' % train_step)\n",
    "            if not tf.train.checkpoint_exists(model):\n",
    "                print('No checkpoint file found')\n",
    "                return\n",
    "            saver.restore(sess, model)\n",
    "\n",
    "            coord = tf.train.Coordinator()\n",
    "            threads = tf.train.start_queue_runners(sess=sess, coord=coord)\n",
    "\n",
    "            try:\n",
    "                print('\\nEvaluating......')\n",
    "                # Only whole batches are scored\n",
    "                num_step = n_test // batch_size\n",
    "                num_sample = num_step * batch_size\n",
    "                if num_sample == 0:\n",
    "                    # Fewer samples than one batch: nothing to score, and the\n",
    "                    # average below would divide by zero.\n",
    "                    print('Total testing samples: %d' % num_sample)\n",
    "                    return\n",
    "                step = 0\n",
    "                total_correct = 0\n",
    "                while step < num_step and not coord.should_stop():\n",
    "                    batch_correct = sess.run(correct)\n",
    "                    total_correct += np.sum(batch_correct)\n",
    "                    step += 1\n",
    "                print('Total testing samples: %d' % num_sample)\n",
    "                print('Total correct predictions: %d' % total_correct)\n",
    "                print('Average accuracy: %.2f%%' % (100 * total_correct / num_sample))\n",
    "            except Exception as e:\n",
    "                # Hand the error to the coordinator; join() below re-raises it\n",
    "                coord.request_stop(e)\n",
    "            finally:\n",
    "                coord.request_stop()\n",
    "                coord.join(threads)\n",
    "#%%"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step: 0, loss: 1.4936, accuracy: 12.5000%\n"
     ]
    },
    {
     "ename": "NameError",
     "evalue": "name 'x' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-2-93fd337a0d5c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-1-9ce8113613f3>\u001b[0m in \u001b[0;36mtrain\u001b[0;34m()\u001b[0m\n\u001b[1;32m     97\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mstep\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;36m100\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mstep\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mMAX_STEP\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     98\u001b[0m                 \u001b[0mval_images\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_labels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mval_image_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_label_batch\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 99\u001b[0;31m                 \u001b[0mval_loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_acc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maccuracy\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mfeed_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mval_images\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0my_\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mval_labels\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    100\u001b[0m                 \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'**  Step %d, val loss = %.2f, val accuracy = %.2f%%  **'\u001b[0m \u001b[0;34m%\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_loss\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mval_acc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    101\u001b[0m                 \u001b[0msummary_str\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msummary_op\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'x' is not defined"
     ]
    }
   ],
   "source": [
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
